"""Copy and pasted from https://github.com/pfnet/pfrl/blob/master/pfrl/wrappers/normalize_action_space.py

except that it uses gymnasium rather than gym
"""


import gymnasium as gymn
import gymnasium.spaces
import numpy as np
from rpi import logger


class NormalizeActionSpace(gymn.ActionWrapper):
    """Normalize a Box action space to [-1, 1]^n.

    The agent acts in the normalized space exposed as ``self.action_space``;
    :meth:`action` maps each normalized action linearly back onto the wrapped
    env's original ``[low, high]`` bounds before it reaches the env.
    """

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.action_space, gymn.spaces.Box)
        orig_space = env.action_space
        # gymnasium's Box defaults to float32; without an explicit dtype a
        # float64 action space would be down-cast (with a precision warning).
        # Non-floating originals keep the previous float32 behaviour, since a
        # [-1, 1] integer Box would only contain {-1, 0, 1}.
        if np.issubdtype(orig_space.dtype, np.floating):
            dtype = orig_space.dtype
        else:
            dtype = np.float32
        self.action_space = gymn.spaces.Box(
            low=-np.ones_like(orig_space.low),
            high=np.ones_like(orig_space.high),
            dtype=dtype,
        )

    def action(self, action):
        """Map a normalized action in [-1, 1] onto the env's original bounds.

        Args:
            action: Array-like action in the normalized space.

        Returns:
            np.ndarray: ``low + (action + 1) * ((high - low) / 2)``, where
            ``low``/``high`` are the wrapped env's action-space bounds. The
            input is never modified.
        """
        low = self.env.action_space.low
        high = self.env.action_space.high
        # asarray accepts lists/tuples; computing out-of-place avoids mutating
        # the caller's array and avoids in-place float-into-int casting errors.
        action = np.asarray(action)
        # [-1, 1] -> [0, 2] -> [0, high - low] -> [low, high]
        return low + (action + 1) * ((high - low) / 2)


class GymnasiumStepWrapper(gymn.Wrapper):
    """Change the return values of Gymnasium to follow the original gym API.

    - ``reset`` returns only the observation (the ``info`` dict is dropped).
    - ``step`` returns ``(observation, reward, terminated, info)``. Truncation
      is not reported as a separate flag; instead ``info['needs_reset']`` is
      set to ``True``, which is how pfrl is told a trajectory was truncated.
    """

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return only the initial observation.

        All arguments (e.g. ``seed=``, ``options=``) are forwarded to the
        wrapped env's ``reset``.
        """
        obs, _info = self.env.reset(*args, **kwargs)
        return obs

    def step(self, obs):
        """Step the wrapped env and return a 4-tuple in the old gym format.

        Args:
            obs: The action to apply. Despite its name, this is passed
                straight to the wrapped env's ``step``; the name is kept for
                backward compatibility with keyword callers.

        Returns:
            tuple: ``(observation, reward, terminated, info)``. When the
            wrapped env reports truncation, ``info['needs_reset']`` is set
            to ``True`` (the wrapped env's ``info`` dict is updated in place).
        """
        observation, reward, terminated, truncated, info = self.env.step(obs)
        if truncated:
            # pfrl's way to tell that the trajectory is truncated
            info['needs_reset'] = True

        return observation, reward, terminated, info


class ReduceObsWrapper(gymn.ObservationWrapper):
    """Keep only the first ``obs_cutoff_idx`` entries of a 1-D Box observation.

    The exposed ``observation_space`` is the wrapped env's Box restricted to
    the same leading entries, so its bounds and dtype match what
    :meth:`observation` actually returns.
    """

    def __init__(self, env, obs_cutoff_idx):
        """
        Args:
            env: Env whose observation space is a 1-D Box.
            obs_cutoff_idx: Number of leading observation entries to keep;
                must be in ``[1, original_dim]``.

        Raises:
            ValueError: If ``obs_cutoff_idx`` is outside ``[1, original_dim]``.
        """
        super().__init__(env)
        self.obs_cutoff_idx = obs_cutoff_idx

        orig_space = self.observation_space
        assert len(orig_space.low.shape) == 1
        orig_dim = orig_space.low.shape[0]
        # A negative or oversized cutoff would otherwise declare a shape that
        # differs from the slice returned by observation().
        if not 0 < obs_cutoff_idx <= orig_dim:
            raise ValueError(
                f"obs_cutoff_idx must be in [1, {orig_dim}], got {obs_cutoff_idx}"
            )

        # Slice the original bounds instead of hard-coding uint8 [0, 255],
        # which misdescribed any non-uint8 observation.
        self.observation_space = gymn.spaces.Box(
            low=orig_space.low[:obs_cutoff_idx],
            high=orig_space.high[:obs_cutoff_idx],
            dtype=orig_space.dtype,
        )

    def observation(self, obs):
        """Return the first ``obs_cutoff_idx`` entries of ``obs``."""
        return obs[:self.obs_cutoff_idx]
